package cb

import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
import org.apache.spark.ml.linalg.{SparseVector, Vectors}
import org.apache.spark.sql.{Row, SparkSession}

/**
  * Created by hunter.coder 涛哥  
  * 2019/5/5 11:40
  * 交流qq:657270652
  * Version: 1.0
  * 更多学习资料：https://blog.csdn.net/coderblack/
  * Description:
  **/
object ItemSimilary {


  /**
    * Euclidean-distance based similarity.
    *
    * Maps the Euclidean distance d between the two vectors into (0, 1]
    * via 1 / (1 + d): identical vectors score 1.0, and the score decays
    * toward 0 as the vectors move apart.
    *
    * @param vec1 first sparse feature vector
    * @param vec2 second sparse feature vector (same dimension as vec1)
    * @return similarity score in (0, 1]
    */
  def ogldSimilarity(vec1: SparseVector, vec2: SparseVector): Double = {
    // Vectors.sqdist returns the squared distance; take the root to get d.
    val euclidean = math.sqrt(Vectors.sqdist(vec1, vec2))
    1.0 / (1.0 + euclidean)
  }


  /**
    * Cosine similarity computed directly on the sparse representations.
    *
    * The dot product is accumulated with a single two-pointer merge over the
    * two (sorted, strictly increasing) index arrays — Spark's SparseVector
    * guarantees sorted indices — instead of the quadratic
    * intersect + indexOf lookup the naive approach would use.
    *
    * @param vec1 first sparse feature vector
    * @param vec2 second sparse feature vector (same dimension as vec1)
    * @return cosine of the angle between the vectors, in [-1, 1];
    *         0.0 when either vector has zero norm (instead of NaN)
    */
  def cosineSimilarity(vec1: SparseVector, vec2: SparseVector): Double = {
    val mo1 = vec1.values.map(v => v * v).sum
    val mo2 = vec2.values.map(v => v * v).sum
    val fenmu = Math.sqrt(mo1 * mo2)

    // Guard: an all-zero vector would make the denominator 0 and the
    // original expression return NaN; define the similarity as 0.0 instead.
    if (fenmu == 0.0) 0.0
    else {
      val indices1 = vec1.indices
      val indices2 = vec2.indices

      // Two-pointer merge over the sorted index arrays: advance the pointer
      // with the smaller index; on a match, both positions refer to the same
      // feature, so multiply the corresponding values into the dot product.
      var i = 0
      var j = 0
      var fenzi = 0.0
      while (i < indices1.length && j < indices2.length) {
        val idx1 = indices1(i)
        val idx2 = indices2(j)
        if (idx1 == idx2) {
          fenzi += vec1.values(i) * vec2.values(j)
          i += 1
          j += 1
        } else if (idx1 < idx2) {
          i += 1
        } else {
          j += 1
        }
      }

      fenzi / fenmu
    }
  }


  /**
    * Cosine similarity, variant 2: densify both vectors and take a
    * position-wise dot product.
    *
    * Example of the densified pair being multiplied element by element:
    *   [0,   2.0, 4.5, 0  ]
    *   [2.0, 2.5, 0,   1.8]
    *
    * @param vec1 first sparse feature vector
    * @param vec2 second sparse feature vector (same dimension as vec1)
    * @return cosine of the angle between the vectors
    */
  def cosineSimilarity2(vec1: SparseVector, vec2: SparseVector): Double = {
    // Denominator: sqrt of the product of the two squared norms.
    val sumSq1 = vec1.values.map(Math.pow(_, 2)).sum
    val sumSq2 = vec2.values.map(Math.pow(_, 2)).sum
    val denominator = Math.sqrt(sumSq1 * sumSq2)

    // Numerator: fold the paired dense entries into a running dot product.
    val dense1 = vec1.toDense.values
    val dense2 = vec2.toDense.values
    val numerator = dense1.zip(dense2).foldLeft(0.0) {
      case (acc, (a, b)) => acc + a * b
    }

    numerator / denominator
  }


  /**
    * Entry point: reads item profiles from CSV, builds TF-IDF vectors from
    * the concatenated category/keyword text, computes pairwise item
    * similarities (Euclidean + cosine) over a cross join, and writes the
    * result as parquet.
    *
    * NOTE(review): input and output paths are hard-coded Windows paths —
    * consider taking them from `args`.
    */
  def main(args: Array[String]): Unit = {

    // Silence Spark's verbose INFO logging; keep warnings and errors only.
    Logger.getLogger("org").setLevel(Level.WARN)
    val spark = SparkSession.builder().appName("item-similary").master("local").getOrCreate()

    import spark.implicits._
    import org.apache.spark.sql.functions._


    // Raw item profiles: pid, three category levels (cat1..cat3), and a
    // pre-segmented keyword string (kwds).
    val df = spark.read.option("header", "true").csv("G:\\data_shark\\doit_recommender\\src\\test\\java\\cb_rec\\item.profile.dat")

    /**
      * Sample input:
      * +---+----+----+----+--------------------------------------------------------------------------+
      * |pid|cat1|cat2|cat3|kwds                                                                      |
      * +---+----+----+----+--------------------------------------------------------------------------+
      * |p01|电子  |数码  |手机  |Apple iPhone XR (A2108) 128GB 黑色 移动 联通 电信 4G手机 双卡双待                       |
      * |p02|电子  |数码  |手环  |小米 红米Redmi Note7 幻彩渐变 AI双摄 4GB+64GB 梦幻蓝 全网通4G 双卡双待 水滴 全面屏 拍照 游戏 智能 手机     |
      * |p03|电子  |数码  |手环  |华为 HUAWEI P30 超感光 徕卡三摄 麒麟 980AI 智能芯片 全面屏 屏内指纹 版 手机 8GB 128GB 亮黑色 全网通 双4G手机|
      * |p04|电子  |数码  |手机  |中兴 全面屏 超长 待机 AI双摄 性价比 流量无限                                                |
      * |p05|食品饮料|休闲食品|坚果炒货|三只松鼠 坚果 炒货 孕妇 坚果 每日 坚果 干果 零食 奶油味 夏威夷果 160g/袋                              |
      * +---+----+----+----+--------------------------------------------------------------------------+
      */

    // Merge the category columns and the keyword string into one space-joined
    // text field per item (DataFrame DSL style: SQL written as method calls).
    //df.select(concat_ws(" ",col("cat1"),col("cat2")))
    //df.select(concat_ws(" ",$"cat1",$"cat2",$"cat3",$"kwds"))
    val df_concate = df.select('pid, concat_ws(" ", 'cat1, 'cat2, 'cat3, 'kwds).as("kwds"))


    /**
      * After concatenation:
      * +---+-----------------------------------------------------------------------------------+
      * |pid|kwds                                         |
      * +---+-----------------------------------------------------------------------------------+
      * |p01|电子 数码 手机 Apple iPhone XR (A2108) 128GB 黑色 移动 联通 电信 4G手机 双卡双待                       |
      * |p02|电子 数码 手环 小米 红米Redmi Note7 幻彩渐变 AI双摄 4GB+64GB 梦幻蓝 全网通4G 双卡双待 水滴 全面屏 拍照 游戏 智能 手机     |
      * |p03|电子 数码 手环 华为 HUAWEI P30 超感光 徕卡三摄 麒麟 980AI 智能芯片 全面屏 屏内指纹 版 手机 8GB 128GB 亮黑色 全网通 双4G手机|
      * |p04|电子 数码 手机 中兴 全面屏 超长 待机 AI双摄 性价比 流量无限                                                |
      * |p05|食品饮料 休闲食品 坚果炒货 三只松鼠 坚果 炒货 孕妇 坚果 每日 坚果 干果 零食 奶油味 夏威夷果 160g/袋                        |
      * +---+-----------------------------------------------------------------------------------+
      */

    // Tokenize. NOTE: Spark's Tokenizer only splits on whitespace — it works
    // here because the text is already word-segmented. For raw (unsegmented)
    // Chinese text, use a real analyzer such as IKAnalyzer or HanLP instead.
    val tokenizer = new Tokenizer().setInputCol("kwds").setOutputCol("words")
    // df_words is still a DataFrame; its `words` column is an array of strings.
    val df_words = tokenizer.transform(df_concate)


    // Vectorize the word arrays with the hashing trick (term-frequency
    // vectors over a 10000-bucket feature space).
    val tf = new HashingTF().setInputCol("words").setOutputCol("tf_vec").setNumFeatures(10000)
    val df_tfVec = tf.transform(df_words).drop("kwds", "words")
    // A sparse vector can be converted to a dense one with .toDense
    // (debug print of the first rows).
    df_tfVec.rdd.map(row => row.getAs[SparseVector]("tf_vec").toDense).take(10).foreach(println)

    // Re-weight the TF vectors into TF-IDF vectors.
    val idf = new IDF().setInputCol("tf_vec").setOutputCol("tfidf_vec")
    val tfidfModel = idf.fit(df_tfVec)
    val df_tfidfVec = tfidfModel.transform(df_tfVec).drop("tf_vec")

    df_tfidfVec.show(10, false)

    /**
      * TF-IDF vectors per item:
      * +---+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
      * |pid|tfidf_vec                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               |
      * +---+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
      * |p01|(10000,[1827,2906,2997,3748,4152,5003,5231,6330,7167,7794,7807,8454,8975,9477],[1.0986122886681098,0.1823215567939546,1.0986122886681098,1.0986122886681098,1.0986122886681098,1.0986122886681098,1.0986122886681098,0.1823215567939546,1.0986122886681098,0.1823215567939546,1.0986122886681098,1.0986122886681098,0.6931471805599453,0.6931471805599453])                                                                                                                                             |
      * |p02|(10000,[1249,1396,2612,2906,3175,3526,4368,4948,5089,5398,6330,6334,6343,6469,6697,7794,9376,9477],[1.0986122886681098,1.0986122886681098,0.4054651081081644,0.1823215567939546,1.0986122886681098,0.6931471805599453,1.0986122886681098,1.0986122886681098,0.6931471805599453,1.0986122886681098,0.1823215567939546,1.0986122886681098,1.0986122886681098,1.0986122886681098,1.0986122886681098,0.1823215567939546,1.0986122886681098,0.6931471805599453])                                             |
      * |p03|(10000,[54,611,1189,1643,1704,2612,2822,2906,3526,3528,3803,4605,6062,6330,6378,7794,8187,8928,8975,9542],[1.0986122886681098,1.0986122886681098,1.0986122886681098,1.0986122886681098,1.0986122886681098,0.4054651081081644,1.0986122886681098,0.1823215567939546,0.6931471805599453,1.0986122886681098,1.0986122886681098,1.0986122886681098,1.0986122886681098,0.1823215567939546,1.0986122886681098,0.1823215567939546,1.0986122886681098,1.0986122886681098,0.6931471805599453,1.0986122886681098])|
      * |p04|(10000,[1084,2612,2906,5089,6330,6432,7636,7794,8772,9436],[1.0986122886681098,0.4054651081081644,0.1823215567939546,0.6931471805599453,0.1823215567939546,1.0986122886681098,1.0986122886681098,0.1823215567939546,1.0986122886681098,1.0986122886681098])                                                                                                                                                                                                                                             |
      * |p05|(10000,[65,443,2736,3413,3637,3928,3945,4632,5932,7422,7489,8570,9371],[1.0986122886681098,1.0986122886681098,1.0986122886681098,1.0986122886681098,1.0986122886681098,1.0986122886681098,1.0986122886681098,1.0986122886681098,1.0986122886681098,1.0986122886681098,3.295836866004329,1.0986122886681098,1.0986122886681098])                                                                                                                                                                         |
      * +---+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
      */

    // Pairwise item similarity (Euclidean / cosine): cross join every item
    // with every other, keeping only pid1 < pid2 so each unordered pair
    // appears once and self-pairs are excluded.
    val df_join = df_tfidfVec.select('pid as "pid1", 'tfidf_vec as "vec1").crossJoin(df_tfidfVec.select('pid as "pid2", 'tfidf_vec as "vec2"))
      .filter('pid1.lt('pid2))

    df_join.show(10, false)


    df_join.map(row => {

      val pid1 = row.getAs[String]("pid1")
      val pid2 = row.getAs[String]("pid2")

      val vec1 = row.getAs[SparseVector]("vec1")
      val vec2 = row.getAs[SparseVector]("vec2")

      // Euclidean-distance similarity.
      val osim = ogldSimilarity(vec1, vec2)
      // Cosine similarity.
      val cosim = cosineSimilarity(vec1, vec2)

      (pid1, pid2, osim, cosim)
    }).toDF("pid1", "pid2", "osim", "cosim")
      .write.parquet("G:\\data_shark\\doit_recommender\\src\\test\\data\\item_sim\\")


    spark.close()
  }

}
